{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/2_train_dagger.ipynb)\n",
    "# Train an Agent using the DAgger Algorithm\n",
    "\n",
    "The DAgger (Dataset Aggregation) algorithm is an extension of behavior cloning.\n",
    "In behavior cloning, the training trajectories are recorded directly from an expert.\n",
    "In DAgger, the learner generates the trajectories, and an expert relabels each visited state with the action it would have taken there.\n",
    "This keeps the state distribution of the training data close to the states that the learner's own policy actually visits."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need an expert to learn from. For convenience, we download a pretrained PPO expert from the Hugging Face model hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "from imitation.policies.serialize import load_policy\n",
    "from imitation.util.util import make_vec_env\n",
    "\n",
    "env = make_vec_env(\n",
    "    \"seals:seals/CartPole-v0\",\n",
    "    rng=np.random.default_rng(),\n",
    "    n_envs=1,\n",
    ")\n",
    "expert = load_policy(\n",
    "    \"ppo-huggingface\",\n",
    "    organization=\"HumanCompatibleAI\",\n",
    "    env_name=\"seals/CartPole-v0\",\n",
    "    venv=env,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can construct a DAgger trainer and use it to train the policy on the CartPole environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "from imitation.algorithms import bc\n",
    "from imitation.algorithms.dagger import SimpleDAggerTrainer\n",
    "\n",
    "bc_trainer = bc.BC(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    rng=np.random.default_rng(),\n",
    ")\n",
    "\n",
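    "# The scratch directory holds DAgger's intermediate training data.\n",
    "# It is deleted when the `with` block exits.\n",
    "with tempfile.TemporaryDirectory(prefix=\"dagger_example_\") as tmpdir:\n",
    "    print(tmpdir)\n",
    "    dagger_trainer = SimpleDAggerTrainer(\n",
    "        venv=env,\n",
    "        scratch_dir=tmpdir,\n",
    "        expert_policy=expert,\n",
    "        bc_trainer=bc_trainer,\n",
    "        rng=np.random.default_rng(),\n",
    "    )\n",
    "\n",
    "    # Train for a total of 2000 environment timesteps.\n",
    "    dagger_trainer.train(2000)"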
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, the evaluation shows that we actually trained a policy that solves the environment (500 is the maximum reward)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
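    "# evaluate_policy returns the mean and standard deviation of the episode reward.\n",
    "reward, _ = evaluate_policy(dagger_trainer.policy, env, n_eval_episodes=20)\n",
    "print(reward)"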
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
